{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Siren Exploration\n",
    "\n",
    "This is a Colab notebook to explore properties of the Siren MLP, proposed in our work [Implicit Neural Representations with Periodic Activation Functions](https://vsitzmann.github.io/siren).\n",
    "\n",
    "\n",
    "We will first implement a streamlined version of Siren for fast experimentation. It lacks the code for easy baseline comparisons (please refer to the main code for that), but this greatly simplifies the implementation!\n",
    "\n",
    "**Make sure that you have enabled the GPU under Edit -> Notebook Settings!**\n",
    "\n",
    "We will then reproduce the following results from the paper: \n",
    "* [Fitting an image](#section_1)\n",
    "* [Fitting an audio signal](#section_2)\n",
    "* [Solving Poisson's equation](#section_3)\n",
    "* [Initialization scheme & distribution of activations](#activations)\n",
    "* [Distribution of activations is shift-invariant](#shift_invariance)\n",
    "\n",
    "We will also explore Siren's [behavior outside of the training range](#out_of_range).\n",
    "\n",
    "Let's go! First, some imports, and a function to quickly generate coordinate grids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import os\n",
    "from collections import OrderedDict  # used by Siren.forward_with_activations\n",
    "\n",
    "from PIL import Image\n",
    "from torchvision.transforms import Resize, Compose, ToTensor, Normalize\n",
    "import numpy as np\n",
    "import skimage\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import time\n",
    "\n",
    "def get_mgrid(sidelen, dim=2):\n",
    "    '''Generates a flattened grid of (x,y,...) coordinates in a range of -1 to 1.\n",
    "    sidelen: int\n",
    "    dim: int'''\n",
    "    tensors = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])\n",
    "    mgrid = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1)\n",
    "    mgrid = mgrid.reshape(-1, dim)\n",
    "    return mgrid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we implement the sine layer, the basic building block of SIREN. This implementation is much more concise than the one in the main code, since here we aren't concerned with baseline comparisons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SineLayer(nn.Module):\n",
    "    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.\n",
    "    \n",
    "    # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the \n",
    "    # nonlinearity. Different signals may require different omega_0 in the first layer - this is a \n",
    "    # hyperparameter.\n",
    "    \n",
    "    # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of \n",
    "    # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)\n",
    "    \n",
    "    def __init__(self, in_features, out_features, bias=True,\n",
    "                 is_first=False, omega_0=30):\n",
    "        super().__init__()\n",
    "        self.omega_0 = omega_0\n",
    "        self.is_first = is_first\n",
    "        \n",
    "        self.in_features = in_features\n",
    "        self.linear = nn.Linear(in_features, out_features, bias=bias)\n",
    "        \n",
    "        self.init_weights()\n",
    "    \n",
    "    def init_weights(self):\n",
    "        with torch.no_grad():\n",
    "            if self.is_first:\n",
    "                self.linear.weight.uniform_(-1 / self.in_features, \n",
    "                                             1 / self.in_features)      \n",
    "            else:\n",
    "                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, \n",
    "                                             np.sqrt(6 / self.in_features) / self.omega_0)\n",
    "        \n",
    "    def forward(self, input):\n",
    "        return torch.sin(self.omega_0 * self.linear(input))\n",
    "    \n",
    "    def forward_with_intermediate(self, input): \n",
    "        # For visualization of activation distributions\n",
    "        intermediate = self.omega_0 * self.linear(input)\n",
    "        return torch.sin(intermediate), intermediate\n",
    "    \n",
    "    \n",
    "class Siren(nn.Module):\n",
    "    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, \n",
    "                 first_omega_0=30, hidden_omega_0=30.):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.net = []\n",
    "        self.net.append(SineLayer(in_features, hidden_features, \n",
    "                                  is_first=True, omega_0=first_omega_0))\n",
    "\n",
    "        for i in range(hidden_layers):\n",
    "            self.net.append(SineLayer(hidden_features, hidden_features, \n",
    "                                      is_first=False, omega_0=hidden_omega_0))\n",
    "\n",
    "        if outermost_linear:\n",
    "            final_linear = nn.Linear(hidden_features, out_features)\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, \n",
    "                                              np.sqrt(6 / hidden_features) / hidden_omega_0)\n",
    "                \n",
    "            self.net.append(final_linear)\n",
    "        else:\n",
    "            self.net.append(SineLayer(hidden_features, out_features, \n",
    "                                      is_first=False, omega_0=hidden_omega_0))\n",
    "        \n",
    "        self.net = nn.Sequential(*self.net)\n",
    "    \n",
    "    def forward(self, coords):\n",
    "        coords = coords.clone().detach().requires_grad_(True)  # allows taking derivatives w.r.t. the input\n",
    "        output = self.net(coords)\n",
    "        return output, coords        \n",
    "\n",
    "    def forward_with_activations(self, coords, retain_grad=False):\n",
    "        '''Returns not only model output, but also intermediate activations.\n",
    "        Only used for visualizing activations later!'''\n",
    "        activations = OrderedDict()\n",
    "\n",
    "        activation_count = 0\n",
    "        x = coords.clone().detach().requires_grad_(True)\n",
    "        activations['input'] = x\n",
    "        for i, layer in enumerate(self.net):\n",
    "            if isinstance(layer, SineLayer):\n",
    "                x, intermed = layer.forward_with_intermediate(x)\n",
    "                \n",
    "                if retain_grad:\n",
    "                    x.retain_grad()\n",
    "                    intermed.retain_grad()\n",
    "                    \n",
    "                activations['_'.join((str(layer.__class__), \"%d\" % activation_count))] = intermed\n",
    "                activation_count += 1\n",
    "            else: \n",
    "                x = layer(x)\n",
    "                \n",
    "                if retain_grad:\n",
    "                    x.retain_grad()\n",
    "                    \n",
    "            activations['_'.join((str(layer.__class__), \"%d\" % activation_count))] = x\n",
    "            activation_count += 1\n",
    "\n",
    "        return activations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we define differential operators that leverage autograd to compute gradients, the Laplacian, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def laplace(y, x):\n",
    "    grad = gradient(y, x)\n",
    "    return divergence(grad, x)\n",
    "\n",
    "\n",
    "def divergence(y, x):\n",
    "    div = 0.\n",
    "    for i in range(y.shape[-1]):\n",
    "        div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1]\n",
    "    return div\n",
    "\n",
    "\n",
    "def gradient(y, x, grad_outputs=None):\n",
    "    if grad_outputs is None:\n",
    "        grad_outputs = torch.ones_like(y)\n",
    "    grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]\n",
    "    return grad"
   ]
  },
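  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of these operators, the sketch below applies the same autograd pattern to a function with a known closed-form Laplacian. To stay self-contained it redefines the operators under the hypothetical names `autograd_gradient` and `autograd_laplacian` (mirroring `gradient` and `laplace` above) instead of relying on earlier cells. For $f(x_1, x_2) = \\sin(x_1)\\cos(x_2)$, the exact Laplacian is $\\Delta f = -2\\sin(x_1)\\cos(x_2)$; float64 is used so that the comparison is limited only by rounding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def autograd_gradient(y, x):\n",
    "    # dy/dx per sample; create_graph=True keeps the graph so we can differentiate again\n",
    "    return torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]\n",
    "\n",
    "def autograd_laplacian(y, x):\n",
    "    # Laplacian = divergence of the gradient = sum over i of d^2 y / dx_i^2\n",
    "    grad = autograd_gradient(y, x)\n",
    "    lap = torch.zeros_like(y)\n",
    "    for i in range(x.shape[-1]):\n",
    "        second = torch.autograd.grad(grad[..., i], x, torch.ones_like(grad[..., i]),\n",
    "                                     create_graph=True)[0]\n",
    "        lap = lap + second[..., i:i+1]\n",
    "    return lap\n",
    "\n",
    "torch.manual_seed(0)\n",
    "pts = torch.rand(100, 2, dtype=torch.float64, requires_grad=True)  # random test points\n",
    "f_test = torch.sin(pts[:, :1]) * torch.cos(pts[:, 1:])  # shape (100, 1)\n",
    "\n",
    "lap_autograd = autograd_laplacian(f_test, pts)\n",
    "lap_exact = -2 * torch.sin(pts[:, :1]) * torch.cos(pts[:, 1:])\n",
    "print('max abs error:', (lap_autograd - lap_exact).abs().max().item())"
   ]
  },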
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments\n",
    "\n",
    "For the image fitting and Poisson experiments, we'll use the classic cameraman image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cameraman_tensor(sidelength):\n",
    "    img = Image.fromarray(skimage.data.camera())        \n",
    "    transform = Compose([\n",
    "        Resize(sidelength),\n",
    "        ToTensor(),\n",
    "        Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))\n",
    "    ])\n",
    "    img = transform(img)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='section_1'></a>\n",
    "## Fitting an image\n",
    "\n",
    "First, let's simply fit that image!\n",
    "\n",
    "We seek to parameterize a grayscale image $f(\\mathbf{x})$ with pixel coordinates $\\mathbf{x}$ with a SIREN $\\Phi(\\mathbf{x})$.\n",
    "\n",
    "That is, we seek the function $\\Phi$ such that:\n",
    "$\\mathcal{L}=\\int_{\\Omega} \\lVert \\Phi(\\mathbf{x}) - f(\\mathbf{x}) \\rVert\\mathrm{d}\\mathbf{x}$\n",
    " is minimized, in which $\\Omega$ is the domain of the image.\n",
    "\n",
    "We write a little dataset that does nothing except compute per-pixel coordinates:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageFitting(Dataset):\n",
    "    def __init__(self, sidelength):\n",
    "        super().__init__()\n",
    "        img = get_cameraman_tensor(sidelength)\n",
    "        self.pixels = img.permute(1, 2, 0).view(-1, 1)\n",
    "        self.coords = get_mgrid(sidelength, 2)\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    def __getitem__(self, idx):    \n",
    "        if idx > 0: raise IndexError\n",
    "            \n",
    "        return self.coords, self.pixels"
   ]
  },
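  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ImageFitting` relies on the flattened pixels and the flattened coordinate grid being listed in the same order. The sketch below checks this on a hypothetical 4x4 test image whose pixel values equal their flat index. It rebuilds the grid the way `get_mgrid` does, assuming `'ij'` indexing (the default of `torch.meshgrid`; the explicit `indexing` keyword needs PyTorch 1.10 or newer), under which the first coordinate follows the image row and the second follows the column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "sidelen = 4\n",
    "# Hypothetical test image: each pixel stores its own flat index row * sidelen + col\n",
    "rows, cols = torch.meshgrid(torch.arange(sidelen), torch.arange(sidelen), indexing='ij')\n",
    "test_img = (rows * sidelen + cols).float()[None]  # shape (1, H, W), like get_cameraman_tensor\n",
    "test_pixels = test_img.permute(1, 2, 0).reshape(-1, 1)  # same flattening as ImageFitting\n",
    "\n",
    "# Same construction as get_mgrid(sidelen, 2)\n",
    "axis = torch.linspace(-1, 1, steps=sidelen)\n",
    "test_coords = torch.stack(torch.meshgrid(axis, axis, indexing='ij'), dim=-1).reshape(-1, 2)\n",
    "\n",
    "# Entry k of both tensors refers to row k // sidelen and column k % sidelen\n",
    "for k in range(sidelen * sidelen):\n",
    "    assert test_pixels[k, 0].item() == k\n",
    "    assert test_coords[k, 0] == axis[k // sidelen]\n",
    "    assert test_coords[k, 1] == axis[k % sidelen]\n",
    "print('pixel/coordinate ordering is consistent')"
   ]
  },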
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's instantiate the dataset and our Siren. As pixel coordinates are 2D, the Siren has 2 input features, and since the image is grayscale, it has one output channel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cameraman = ImageFitting(256)\n",
    "dataloader = DataLoader(cameraman, batch_size=1, pin_memory=True, num_workers=0)\n",
    "\n",
    "img_siren = Siren(in_features=2, out_features=1, hidden_features=256, \n",
    "                  hidden_layers=3, outermost_linear=True)\n",
    "img_siren.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now fit Siren in a simple training loop. Within only a few hundred iterations, both the image and its gradients are approximated well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_steps = 500 # Since the whole image is our dataset, this just means 500 gradient descent steps.\n",
    "steps_til_summary = 10\n",
    "\n",
    "optim = torch.optim.Adam(lr=1e-4, params=img_siren.parameters())\n",
    "\n",
    "model_input, ground_truth = next(iter(dataloader))\n",
    "model_input, ground_truth = model_input.cuda(), ground_truth.cuda()\n",
    "\n",
    "for step in range(total_steps):\n",
    "    model_output, coords = img_siren(model_input)    \n",
    "    loss = ((model_output - ground_truth)**2).mean()\n",
    "    \n",
    "    if not step % steps_til_summary:\n",
    "        print(\"Step %d, Total loss %0.6f\" % (step, loss))\n",
    "        img_grad = gradient(model_output, coords)\n",
    "        img_laplacian = laplace(model_output, coords)\n",
    "\n",
    "        fig, axes = plt.subplots(1,3, figsize=(18,6))\n",
    "        axes[0].imshow(model_output.cpu().view(256,256).detach().numpy())\n",
    "        axes[1].imshow(img_grad.norm(dim=-1).cpu().view(256,256).detach().numpy())\n",
    "        axes[2].imshow(img_laplacian.cpu().view(256,256).detach().numpy())\n",
    "        plt.show()\n",
    "\n",
    "    optim.zero_grad()\n",
    "    loss.backward()\n",
    "    optim.step()"
   ]
  },
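  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put a number on how well the image is approximated, a common metric is the peak signal-to-noise ratio (PSNR). The `psnr` helper below is a hypothetical addition, not part of the original code. It assumes intensities in $[-1, 1]$, as produced by the normalization in `get_cameraman_tensor`, so the peak-to-peak range is 2 and $\\mathrm{PSNR} = 10\\log_{10}(2^2/\\mathrm{MSE})$. After the training loop above has run, `psnr(model_output.detach(), ground_truth)` reports the fit in decibels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "def psnr(prediction, target, data_range=2.0):\n",
    "    # Peak signal-to-noise ratio in dB; data_range = max - min of the signal (2 for [-1, 1])\n",
    "    mse = torch.mean((prediction - target) ** 2).item()\n",
    "    return 10 * math.log10(data_range ** 2 / mse)\n",
    "\n",
    "# Toy check: a constant error of 0.02 gives MSE = 4e-4, i.e. 10 * log10(4 / 4e-4) = 40 dB\n",
    "toy_target = torch.zeros(256 * 256, 1)\n",
    "toy_prediction = toy_target + 0.02\n",
    "print('PSNR: %.2f dB' % psnr(toy_prediction, toy_target))"
   ]
  },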
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='out_of_range'></a>\n",
    "## Case study: Siren periodicity & out-of-range behavior\n",
    "\n",
    "It is known that the sum of two periodic signals is itself periodic with a period that is equal to the least common multiple of the periods of the two summands, if and only if the two periods are rational multiples of each other. If the ratio of the two periods is irrational, then their sum will *not* be periodic itself.\n",
    "\n",
    "Due to the floating-point representation in neural network libraries, this case cannot occur in practice, and all functions parameterized by Siren indeed have to be periodic.\n",
    "\n",
    "Yet, the period of the resulting function may in practice be several orders of magnitude larger than the period of each Siren neuron!\n",
    "\n",
    "Let's test this with two sines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    coords = get_mgrid(2**10, 1) * 5 * np.pi\n",
    "    \n",
    "    sin_1 = torch.sin(coords)\n",
    "    sin_2 = torch.sin(coords * 2)\n",
    "    sin_sum = sin_1 + sin_2  # avoid shadowing Python's built-in sum()\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(16,2))\n",
    "    ax.plot(coords, sin_sum)\n",
    "    ax.plot(coords, sin_1)\n",
    "    ax.plot(coords, sin_2)\n",
    "    plt.title(\"Rational multiple\")\n",
    "    plt.show()\n",
    "    \n",
    "    sin_1 = torch.sin(coords)\n",
    "    sin_2 = torch.sin(coords * np.pi)\n",
    "    sin_sum = sin_1 + sin_2\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(16,2))\n",
    "    ax.plot(coords, sin_sum)\n",
    "    ax.plot(coords, sin_1)\n",
    "    ax.plot(coords, sin_2)\n",
    "    plt.title(\"Pseudo-irrational multiple\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Though the second plot looks periodic, closer inspection shows that the period of the blue line (the sum) is larger than the range we're sampling here.\n",
    "\n",
    "Let's take a look at what the Siren we just trained looks like outside its training domain!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    out_of_range_coords = get_mgrid(1024, 2) * 50\n",
    "    model_out, _ = img_siren(out_of_range_coords.cuda())\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(16,16))\n",
    "    ax.imshow(model_out.cpu().view(1024,1024).numpy())\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Though there is some self-similarity, the signal does not repeat within the range (-50, 50)."
   ]
  },
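  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To quantify the period in the pseudo-irrational example above: for rational angular frequencies $\\omega_i = p_i / q_i$ in lowest terms, $\\sin(\\omega_1 x) + \\sin(\\omega_2 x)$ repeats with period $T = 2\\pi \\cdot \\mathrm{lcm}(q_1, q_2) / \\gcd(p_1, p_2)$. The sketch below assumes that $\\pi$ enters the computation as a float32 value (the dtype of `coords`), which is exactly $13176795 / 2^{22}$. In exact arithmetic, $\\sin(x) + \\sin(\\pi_{\\mathrm{f32}} x)$ then only repeats after $2\\pi \\cdot 2^{22} \\approx 2.6 \\times 10^7$, far outside the sampled range $[-5\\pi, 5\\pi]$. Rounding inside the floating-point evaluation itself is ignored here, and `common_period` is a hypothetical helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from fractions import Fraction\n",
    "import numpy as np\n",
    "\n",
    "def common_period(omega_1, omega_2):\n",
    "    # Smallest T > 0 after which both sin(omega_1 * x) and sin(omega_2 * x) repeat,\n",
    "    # for rational angular frequencies given as Fractions (kept in lowest terms).\n",
    "    q1, q2 = omega_1.denominator, omega_2.denominator\n",
    "    q_lcm = q1 * q2 // math.gcd(q1, q2)\n",
    "    p_gcd = math.gcd(omega_1.numerator, omega_2.numerator)\n",
    "    return 2 * math.pi * q_lcm / p_gcd\n",
    "\n",
    "# Rational multiple: sin(x) + sin(2x) repeats every 2*pi\n",
    "print(common_period(Fraction(1), Fraction(2)))\n",
    "\n",
    "# The float32 approximation of pi is an exact binary fraction\n",
    "pi_float32 = Fraction(float(np.float32(np.pi)))\n",
    "print(pi_float32)  # 13176795/4194304\n",
    "print('period of sin(x) + sin(pi_f32 * x): %.3e' % common_period(Fraction(1), pi_float32))"
   ]
  },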
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='section_2'></a>\n",
    "## Fitting an audio signal\n",
    "\n",
    "Here, we'll use Siren to parameterize an audio signal, i.e., we seek to parameterize an audio waveform $f(t)$ at time points $t$ by a SIREN $\\Phi$.\n",
    "\n",
    "That is, we seek the function $\\Phi$ such that $\\mathcal{L}=\\int_{\\Omega} \\lVert \\Phi(t) - f(t) \\rVert \\mathrm{d}t$ is minimized, in which $\\Omega$ is the domain of the waveform.\n",
    "\n",
    "For the audio, we'll use a Bach sonata:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.io.wavfile as wavfile\n",
    "import io\n",
    "from IPython.display import Audio\n",
    "\n",
    "if not os.path.exists('gt_bach.wav'):\n",
    "    !wget https://vsitzmann.github.io/siren/img/audio/gt_bach.wav"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's build a little dataset that computes coordinates for audio files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AudioFile(torch.utils.data.Dataset):\n",
    "    def __init__(self, filename):\n",
    "        self.rate, self.data = wavfile.read(filename)\n",
    "        self.data = self.data.astype(np.float32)\n",
    "        self.timepoints = get_mgrid(len(self.data), 1)\n",
    "\n",
    "    def get_num_samples(self):\n",
    "        return self.timepoints.shape[0]\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        amplitude = self.data\n",
    "        scale = np.max(np.abs(amplitude))\n",
    "        amplitude = (amplitude / scale)\n",
    "        amplitude = torch.Tensor(amplitude).view(-1, 1)\n",
    "        return self.timepoints, amplitude"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's instantiate the Siren. As this audio signal has much higher-frequency content over the range [-1, 1], we increase $\\omega_0$ in the first layer of the Siren."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bach_audio = AudioFile('gt_bach.wav')\n",
    "\n",
    "dataloader = DataLoader(bach_audio, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)\n",
    "\n",
    "# Note that we increase the frequency of the first layer to match the higher frequencies of the\n",
    "# audio signal. Equivalently, we could also increase the range of the input coordinates.\n",
    "audio_siren = Siren(in_features=1, out_features=1, hidden_features=256, \n",
    "                    hidden_layers=3, first_omega_0=3000, outermost_linear=True)\n",
    "audio_siren.cuda()"
   ]
  },
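  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comment in the cell above says that raising the first layer's $\\omega_0$ is equivalent to enlarging the range of the input coordinates. The identity behind this is $\\omega_0(\\mathbf{W}t + \\mathbf{b}) = \\omega_0'(\\mathbf{W}(ct) + c\\,\\mathbf{b})$ with $c = \\omega_0 / \\omega_0'$: a first layer with $\\omega_0 = 3000$ computes exactly what a layer with $\\omega_0' = 30$ computes on coordinates scaled by $c = 100$, provided its bias is scaled by the same factor. The sketch below checks this numerically with two plain `nn.Linear` layers (default PyTorch initialization, float64 to limit rounding error); it illustrates the identity only and does not use the Siren initialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "t = torch.linspace(-1, 1, 1001, dtype=torch.float64).view(-1, 1)  # time coordinates in [-1, 1]\n",
    "\n",
    "high_omega_layer = nn.Linear(1, 8).double()    # used with omega_0 = 3000 on t\n",
    "scaled_input_layer = nn.Linear(1, 8).double()  # used with omega_0 = 30 on 100 * t\n",
    "with torch.no_grad():\n",
    "    scaled_input_layer.weight.copy_(high_omega_layer.weight)\n",
    "    scaled_input_layer.bias.copy_(100 * high_omega_layer.bias)  # bias scaled by c = 100\n",
    "\n",
    "    out_high_omega = torch.sin(3000 * high_omega_layer(t))\n",
    "    out_scaled_input = torch.sin(30 * scaled_input_layer(100 * t))\n",
    "\n",
    "print('max abs difference:', (out_high_omega - out_scaled_input).abs().max().item())"
   ]
  },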
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's have a quick listen to the ground truth:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rate, _ = wavfile.read('gt_bach.wav')\n",
    "\n",
    "model_input, ground_truth = next(iter(dataloader))\n",
    "Audio(ground_truth.squeeze().numpy(),rate=rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now fit the Siren to this signal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_steps = 1000 \n",
    "steps_til_summary = 100\n",
    "\n",
    "optim = torch.optim.Adam(lr=1e-4, params=audio_siren.parameters())\n",
    "\n",
    "model_input, ground_truth = next(iter(dataloader))\n",
    "model_input, ground_truth = model_input.cuda(), ground_truth.cuda()\n",
    "\n",
    "for step in range(total_steps):\n",
    "    model_output, coords = audio_siren(model_input)    \n",
    "    loss = F.mse_loss(model_output, ground_truth)\n",
    "    \n",
    "    if not step % steps_til_summary:\n",
    "        print(\"Step %d, Total loss %0.6f\" % (step, loss))\n",
    "    \n",
    "        fig, axes = plt.subplots(1,2)\n",
    "        axes[0].plot(coords.squeeze().detach().cpu().numpy(),model_output.squeeze().detach().cpu().numpy())\n",
    "        axes[1].plot(coords.squeeze().detach().cpu().numpy(),ground_truth.squeeze().detach().cpu().numpy())\n",
    "        plt.show()\n",
    "\n",
    "    optim.zero_grad()\n",
    "    loss.backward()\n",
    "    optim.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "final_model_output, coords = audio_siren(model_input)\n",
    "Audio(final_model_output.cpu().detach().squeeze().numpy(),rate=rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "As we can hear, after only 1000 iterations, Siren approximates the audio signal very well!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='section_3'></a>\n",
    "## Solving Poisson's equation\n",
    "\n",
    "Now, let's make it a bit harder. Let's say we want to reconstruct an image but we only have access to its gradients!\n",
    "\n",
    "That is, we now seek the function $\\Phi$ such that:\n",
    "$\\mathcal{L}=\\int_{\\Omega} \\lVert \\nabla\\Phi(\\mathbf{x}) - \\nabla f(\\mathbf{x}) \\rVert\\mathrm{d}\\mathbf{x}$\n",
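    " is minimized, in which $\\Omega$ is the domain of the image. Because this loss only constrains the gradients of $\\Phi$, the reconstruction is determined only up to an additive constant."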
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import scipy.ndimage\n",
    "    \n",
    "class PoissonEqn(Dataset):\n",
    "    def __init__(self, sidelength):\n",
    "        super().__init__()\n",
    "        img = get_cameraman_tensor(sidelength)\n",
    "        \n",
    "        # Compute gradient and laplacian       \n",
    "        grads_x = scipy.ndimage.sobel(img.numpy(), axis=1).squeeze(0)[..., None]\n",
    "        grads_y = scipy.ndimage.sobel(img.numpy(), axis=2).squeeze(0)[..., None]\n",
    "        grads_x, grads_y = torch.from_numpy(grads_x), torch.from_numpy(grads_y)\n",
    "                \n",
    "        self.grads = torch.stack((grads_x, grads_y), dim=-1).view(-1, 2)\n",
    "        self.laplace = scipy.ndimage.laplace(img.numpy()).squeeze(0)[..., None]\n",
    "        self.laplace = torch.from_numpy(self.laplace)\n",
    "        \n",
    "        self.pixels = img.permute(1, 2, 0).view(-1, 1)\n",
    "        self.coords = get_mgrid(sidelength, 2)\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.coords, {'pixels':self.pixels, 'grads':self.grads, 'laplace':self.laplace}"
   ]
  },
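  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two details of `scipy.ndimage.sobel` are worth keeping in mind when reading `PoissonEqn`. First, `axis=1` of the `(1, H, W)` image array differentiates along the image rows, which is the direction of the first coordinate returned by `get_mgrid`. Second, the Sobel filter is unnormalized: on a linear ramp it returns 8 times the per-pixel slope (a central difference of 2 times the smoothing weights 1 + 2 + 1 = 4), and it knows nothing about the spacing of the $[-1, 1]$ coordinate grid. The toy check below uses a hypothetical 8x8 ramp image stored as a plain 2D array (so rows are `axis=0` there); it only illustrates the filter and does not change the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.ndimage\n",
    "\n",
    "# Hypothetical ramp image: intensity equals the row index, constant along columns\n",
    "ramp = np.arange(8, dtype=np.float64)[:, None] * np.ones((1, 8))\n",
    "\n",
    "d_rows = scipy.ndimage.sobel(ramp, axis=0)  # derivative along rows\n",
    "d_cols = scipy.ndimage.sobel(ramp, axis=1)  # derivative along columns\n",
    "\n",
    "print(d_rows[1:-1, 1:-1])  # interior: 8 = central difference (2) * smoothing weights (1 + 2 + 1)\n",
    "print(np.abs(d_cols).max())  # 0.0: the ramp is constant along columns"
   ]
  },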
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "#### Instantiate SIREN model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "cameraman_poisson = PoissonEqn(128)\n",
    "dataloader = DataLoader(cameraman_poisson, batch_size=1, pin_memory=True, num_workers=0)\n",
    "\n",
    "poisson_siren = Siren(in_features=2, out_features=1, hidden_features=256, \n",
    "                      hidden_layers=3, outermost_linear=True)\n",
    "poisson_siren.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "#### Define the loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def gradients_mse(model_output, coords, gt_gradients):\n",
    "    # compute gradients on the model\n",
    "    gradients = gradient(model_output, coords)\n",
    "    # compare them with the ground-truth\n",
    "    gradients_loss = torch.mean((gradients - gt_gradients).pow(2).sum(-1))\n",
    "    return gradients_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "#### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "total_steps = 1000\n",
    "steps_til_summary = 10\n",
    "\n",
    "optim = torch.optim.Adam(lr=1e-4, params=poisson_siren.parameters())\n",
    "\n",
    "model_input, gt = next(iter(dataloader))\n",
    "gt = {key: value.cuda() for key, value in gt.items()}\n",
    "model_input = model_input.cuda()\n",
    "\n",
    "for step in range(total_steps):\n",
    "    start_time = time.time()\n",
    "\n",
    "    model_output, coords = poisson_siren(model_input)\n",
    "    train_loss = gradients_mse(model_output, coords, gt['grads'])\n",
    "\n",
    "    if not step % steps_til_summary:\n",
    "        print(\"Step %d, Total loss %0.6f, iteration time %0.6f\" % (step, train_loss, time.time() - start_time))\n",
    "\n",
    "        img_grad = gradient(model_output, coords)\n",
    "        img_laplacian = laplace(model_output, coords)\n",
    "\n",
    "        fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
    "        axes[0].imshow(model_output.cpu().view(128,128).detach().numpy())\n",
    "        axes[1].imshow(img_grad.cpu().norm(dim=-1).view(128,128).detach().numpy())\n",
    "        axes[2].imshow(img_laplacian.cpu().view(128,128).detach().numpy())\n",
    "        plt.show()\n",
    "        \n",
    "    optim.zero_grad()\n",
    "    train_loss.backward()\n",
    "    optim.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "<a id='activations'></a>\n",
    "## Initialization scheme & distribution of activations\n",
    "\n",
    "We now reproduce the empirical result on the distribution of activations, and will thereafter show empirically that the distribution of activations is shift-invariant as well! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import matplotlib\n",
    "import numpy.fft as fft\n",
    "import scipy.stats as stats\n",
    "\n",
    "\n",
    "def eformat(f, prec, exp_digits):\n",
    "    s = \"%.*e\"%(prec, f)\n",
    "    mantissa, exp = s.split('e')\n",
    "    # add 1 to digits as 1 is taken by sign +/-\n",
    "    return \"%se%+0*d\"%(mantissa, exp_digits+1, int(exp))\n",
    "\n",
    "def format_x_ticks(x, pos):\n",
    "    \"\"\"Format x-axis tick labels in compact scientific notation.\n",
    "    \"\"\"\n",
    "    return eformat(x, 0, 1)\n",
    "\n",
    "def format_y_ticks(x, pos):\n",
    "    \"\"\"Format y-axis tick labels in compact scientific notation.\n",
    "    \"\"\"\n",
    "    return eformat(x, 0, 1)\n",
    "\n",
    "def get_spectrum(activations):\n",
    "    n = activations.shape[0]\n",
    "\n",
    "    spectrum = fft.fft(activations.numpy().astype(np.double).sum(axis=-1), axis=0)[:n//2]\n",
    "    spectrum = np.abs(spectrum)\n",
    "\n",
    "    max_freq = 100                \n",
    "    freq = fft.fftfreq(n, 2./n)[:n//2]\n",
    "    return freq[:max_freq], spectrum[:max_freq]\n",
    "\n",
    "\n",
    "def plot_all_activations_and_grads(activations):\n",
    "    num_cols = 4\n",
    "    num_rows = len(activations)\n",
    "    \n",
    "    fig_width = 5.5\n",
    "    fig_height = 9  # fixed figure height in inches, independent of num_rows\n",
    "    \n",
    "    fontsize = 5\n",
    "        \n",
    "    fig, axs = plt.subplots(num_rows, num_cols, gridspec_kw={'hspace': 0.3, 'wspace': 0.2},\n",
    "                            figsize=(fig_width, fig_height), dpi=300)\n",
    "    \n",
    "    axs[0][0].set_title(\"Activation Distribution\", fontsize=7, fontfamily='serif', pad=5.)\n",
    "    axs[0][1].set_title(\"Activation Spectrum\", fontsize=7, fontfamily='serif', pad=5.)\n",
    "    axs[0][2].set_title(\"Gradient Distribution\", fontsize=7, fontfamily='serif', pad=5.)\n",
    "    axs[0][3].set_title(\"Gradient Spectrum\", fontsize=7, fontfamily='serif', pad=5.)\n",
    "\n",
    "    x_formatter = matplotlib.ticker.FuncFormatter(format_x_ticks)\n",
    "    y_formatter = matplotlib.ticker.FuncFormatter(format_y_ticks)\n",
    "\n",
    "    spec_rows = []\n",
    "    for idx, (key, value) in enumerate(activations.items()):    \n",
    "        grad_value = value.grad.cpu().detach().squeeze(0)\n",
    "        flat_grad = grad_value.view(-1)\n",
    "        axs[idx][2].hist(flat_grad, bins=256, density=True)\n",
    "        \n",
    "        value = value.cpu().detach().squeeze(0) # (1, num_points, 256)\n",
    "        n = value.shape[0]\n",
    "        flat_value = value.view(-1)\n",
    "            \n",
    "        axs[idx][0].hist(flat_value, bins=256, density=True)\n",
    "                \n",
    "        if idx>1:\n",
    "            if not (idx)%2:\n",
    "                x = np.linspace(-1, 1., 500)\n",
    "                axs[idx][0].plot(x, stats.arcsine.pdf(x, -1, 2), \n",
    "                                 linestyle=':', markersize=0.4, zorder=2)\n",
    "            else:\n",
    "                mu = 0\n",
    "                variance = 1\n",
    "                sigma = np.sqrt(variance)\n",
    "                x = np.linspace(mu - 3*sigma, mu + 3*sigma, 500)\n",
    "                axs[idx][0].plot(x, stats.norm.pdf(x, mu, sigma), \n",
    "                                 linestyle=':', markersize=0.4, zorder=2)\n",
    "        \n",
    "        activ_freq, activ_spec = get_spectrum(value)\n",
    "        axs[idx][1].plot(activ_freq, activ_spec)\n",
    "        \n",
    "        grad_freq, grad_spec = get_spectrum(grad_value)\n",
    "        axs[idx][-1].plot(grad_freq, grad_spec)\n",
    "        \n",
    "        for ax in axs[idx]:\n",
    "            ax.tick_params(axis='both', which='major', direction='in',\n",
    "                                    labelsize=fontsize, pad=1., zorder=10) \n",
    "            ax.tick_params(axis='x', labelrotation=0, pad=1.5, zorder=10) \n",
    "\n",
    "            ax.xaxis.set_major_formatter(x_formatter)\n",
    "            ax.yaxis.set_major_formatter(y_formatter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model = Siren(in_features=1, hidden_features=2048, \n",
    "              hidden_layers=10, out_features=1, outermost_linear=True)\n",
    "\n",
    "input_signal = torch.linspace(-1, 1, 65536//4).view(1, 65536//4, 1)\n",
    "activations = model.forward_with_activations(input_signal, retain_grad=True)\n",
    "output = activations[next(reversed(activations))]\n",
    "\n",
    "# Compute gradients. Because we have retain_grad=True on \n",
    "# activations, each activation stores its own gradient!\n",
    "output.mean().backward()\n",
    "\n",
    "plot_all_activations_and_grads(activations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note how the activations of Siren always alternate between a standard normal distribution (the pre-activations, before the sine) and an arcsine distribution (the outputs of the sine). If you have a beefy computer, you can push this to the extreme and increase the number of layers: this property holds even for more than 50 layers!"
   ]
  },
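  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why standard normal? A minimal sketch of the variance argument, ignoring the bias: if the inputs $x_j$ of a hidden layer are arcsine-distributed on $[-1, 1]$ (variance $1/2$) and the weights are drawn from $\\mathcal{U}(-c, c)$ with $c = \\sqrt{6/n}/\\omega_0$ (variance $c^2/3 = 2/(n\\omega_0^2)$), then the pre-activation $\\omega_0 \\sum_{j=1}^{n} w_j x_j$ has variance $\\omega_0^2 \\cdot n \\cdot \\frac{2}{n\\omega_0^2} \\cdot \\frac{1}{2} = 1$. The NumPy check below samples arcsine inputs as $\\sin(u)$ with $u \\sim \\mathcal{U}(-\\pi, \\pi)$ and uses the same bound as `SineLayer.init_weights`; the layer width and sample count are arbitrary choices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n, omega_0, num_samples = 256, 30.0, 4096  # arbitrary layer width and sample count\n",
    "\n",
    "# Arcsine-distributed inputs on [-1, 1], as produced by a preceding sine layer\n",
    "arcsine_inputs = np.sin(rng.uniform(-np.pi, np.pi, size=(num_samples, n)))\n",
    "\n",
    "# Hidden-layer initialization bound from SineLayer.init_weights (bias ignored)\n",
    "bound = np.sqrt(6 / n) / omega_0\n",
    "hidden_weights = rng.uniform(-bound, bound, size=(n, n))\n",
    "\n",
    "pre_activations = omega_0 * arcsine_inputs @ hidden_weights.T\n",
    "print('input variance: %.3f (theory 0.5)' % arcsine_inputs.var())\n",
    "print('pre-activation std: %.3f (theory 1.0)' % pre_activations.std())"
   ]
  },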
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='shift_invariance'></a>\n",
    "## Distribution of activations is shift-invariant\n",
    "\n",
    "One of the key properties of the periodic sine nonlinearity is that it affords a degree of shift-invariance. Consider the first layer of a Siren: You can convince yourself that this layer can easily learn to map two different coordinates to *the same set of activations*. This means that whatever layers come afterwards will apply the same function to these two sets of coordinates.\n",
    "\n",
    "Moreover, the distribution of activations is similarly shift-invariant. Let's shift our input signal by 1000 and recompute the activations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "input_signal = torch.linspace(-1, 1, 65536//4).view(1, 65536//4, 1) + 1000\n",
    "activations = model.forward_with_activations(input_signal, retain_grad=True)\n",
    "output = activations[next(reversed(activations))]\n",
    "\n",
    "# Compute gradients. Because we have retain_grad=True on \n",
    "# activations, each activation stores its own gradient!\n",
    "output.mean().backward()\n",
    "\n",
    "plot_all_activations_and_grads(activations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the distributions of activations are essentially unchanged by the shift: they are invariant to it."
   ]
  }
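  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The claim at the start of this section, that a first sine layer can map two different coordinates to the same activations, can be checked directly. For a shift $\\Delta$, choosing each weight as $w_k = 2\\pi k / (\\omega_0 \\Delta)$ with integer $k$ gives $\\sin(\\omega_0(w_k(x + \\Delta) + b_k)) = \\sin(\\omega_0(w_k x + b_k) + 2\\pi k) = \\sin(\\omega_0(w_k x + b_k))$. The hand-picked weights, biases, and shift below are hypothetical, chosen only to illustrate this identity; a trained Siren need not use them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "omega_0 = 30.0\n",
    "delta = 1.5                            # hypothetical shift between two coordinates\n",
    "k = np.arange(1, 9)                    # one integer per hidden unit\n",
    "w = 2 * np.pi * k / (omega_0 * delta)  # so that omega_0 * w * delta = 2 * pi * k\n",
    "b = np.random.default_rng(0).uniform(-1, 1, size=k.shape)  # arbitrary biases\n",
    "\n",
    "x0 = 0.3\n",
    "act_x0 = np.sin(omega_0 * (w * x0 + b))\n",
    "act_shifted = np.sin(omega_0 * (w * (x0 + delta) + b))\n",
    "print('max abs difference:', np.abs(act_x0 - act_shifted).max())"
   ]
  }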
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
